import networkx as nx
import numpy as np
import pickle
import collections


def get_minmax_each_att(test_graph_path):
    """Return the minimum and maximum numeric value for every value type.

    The pickled object at ``test_graph_path`` must expose ``nodes()``. Every
    node that is a tuple is read as ``(value, type_id)``: index 0 is the
    number and index 1 is the key the numbers are grouped by. Nodes that are
    not tuples are ignored.

    Args:
        test_graph_path: Path to the pickle file holding the graph.

    Returns:
        A dict mapping each type id to ``{"min": ..., "max": ...}``, computed
        with ``np.min`` / ``np.max`` over all values of that type. The exact
        extrema are used (no percentile clipping). Empty if the graph has no
        tuple nodes.
    """
    # ``nx.read_gpickle`` was removed in NetworkX 3.0; it was a thin wrapper
    # around ``pickle.load``, so load the file directly.  Unlike the old
    # helper, compressed (.gz/.bz2) files are not transparently decompressed.
    # SECURITY: unpickling can execute arbitrary code -- only load files that
    # come from a trusted source.
    with open(test_graph_path, "rb") as handle:
        test_graph = pickle.load(handle)

    values_by_type = collections.defaultdict(list)
    for node in test_graph.nodes():
        if isinstance(node, tuple):
            values_by_type[node[1]].append(node[0])

    # Iterate over the keys that actually occur instead of assuming the type
    # ids are the contiguous integers 0..n-1.
    return {
        type_id: {"min": np.min(values), "max": np.max(values)}
        for type_id, values in values_by_type.items()
    }



DEFAULT_TEST_GRAPH_PATH = "/home/pxs/projects/paper_nrn_new/processing_entityandnum/FB15K_test_with_units.pkl"
DEFAULT_SPLIT_PATH = "/home/pxs/projects/paper_nrn_new/processing_task2/FB15K_edge_splid.pkl"
DEFAULT_OUTPUT_PATH = "preprocessed_data_for_linepredictor.pkl"


def _relation_id(tri):
    """Return the relation/attribute id of a raw triple: the first key of ``tri[2]``."""
    return int(list(tri[2].keys())[0])


def _parse_entity_triples(triples):
    """Convert raw ``(head, tail, {relation: ...})`` triples to ``[h, r, t]`` int lists."""
    return [[int(tri[0]), _relation_id(tri), int(tri[1])] for tri in triples]


def _parse_numeric_triples(triples, num_min_max):
    """Convert raw ``(head, (num, unit), {attribute: ...})`` triples to ``[h, r, value]``.

    ``value`` is ``num`` min-max scaled with the bounds stored under
    ``num_min_max[unit]``.
    """
    parsed = []
    for tri in triples:
        num, unit = tri[1][0], tri[1][1]
        bounds = num_min_max[unit]
        # NOTE(review): if a unit has max == min this divides by zero (NaN/inf
        # with numpy scalars, ZeroDivisionError with plain Python numbers).
        # Behaviour is kept as-is; confirm whether such units can occur.
        value = (num - bounds["min"]) / (bounds["max"] - bounds["min"])
        parsed.append([int(tri[0]), _relation_id(tri), value])
    return parsed


def build_preprocessed_data(splits, num_min_max):
    """Build the tuple that is pickled for the downstream predictor.

    Args:
        splits: 6-tuple ``(train_entity, valid_entity, test_entity,
            train_num, valid_num, test_num)`` of raw triples.
        num_min_max: Mapping ``unit -> {"min": ..., "max": ...}`` used to
            scale numeric values.

    Returns:
        ``(ent_train, ent_valid, ent_test, att_train, att_valid, att_test,
        rhs_dict, lhs_dict, all_entity, all_relation, all_att,
        each_att_entity_train, each_att_entity_test)`` where

        * ``ent_*`` are ``[h, r, t]`` lists; ``ent_train`` additionally ends
          with the dummy edges described below,
        * ``att_*`` are ``[h, r, scaled_value]`` lists,
        * ``rhs_dict[(h, r)]`` / ``lhs_dict[(t, r)]`` list the tails / heads
          seen in the training entity triples and the dummy edges,
        * ``all_entity`` / ``all_relation`` hold the ids seen in the entity
          triples of all three splits plus the dummy node / dummy relations,
        * ``all_att`` holds the attribute ids of all three numeric splits,
        * ``each_att_entity_train`` maps attribute -> heads in the training
          numeric triples; ``each_att_entity_test`` does the same for the
          validation and test numeric triples combined.
    """
    (train_triples_entity, valid_triples_entity, test_triples_entity,
     train_triples_num, valid_triples_num, test_triples_num) = splits

    # ---- entity triples ----
    ent_train = _parse_entity_triples(train_triples_entity)
    ent_valid = _parse_entity_triples(valid_triples_entity)
    ent_test = _parse_entity_triples(test_triples_entity)

    # Answer look-ups are built from the training triples only.
    rhs_dict = collections.defaultdict(list)
    lhs_dict = collections.defaultdict(list)
    for h, r, t in ent_train:
        rhs_dict[(h, r)].append(t)
        lhs_dict[(t, r)].append(h)

    all_entity = set()
    all_relation = set()
    for split in (ent_train, ent_valid, ent_test):
        for h, r, t in split:
            all_entity.add(h)
            all_entity.add(t)
            all_relation.add(r)

    entity_num = len(all_entity)
    relation_num = len(all_relation)

    # ---- numeric (attribute) triples ----
    att_train = _parse_numeric_triples(train_triples_num, num_min_max)
    att_valid = _parse_numeric_triples(valid_triples_num, num_min_max)
    att_test = _parse_numeric_triples(test_triples_num, num_min_max)

    # Entities carrying each attribute (train) and attributes of each entity (train).
    each_att_entity_train = collections.defaultdict(set)
    entity_att_dict = collections.defaultdict(set)
    for h, r, _ in att_train:
        each_att_entity_train[r].add(h)
        entity_att_dict[h].add(r)

    # Validation and test heads are pooled into a single mapping.
    each_att_entity_test = collections.defaultdict(set)
    for split in (att_valid, att_test):
        for h, r, _ in split:
            each_att_entity_test[r].add(h)

    all_att = set()
    for split in (att_train, att_valid, att_test):
        for _, r, _ in split:
            all_att.add(r)
    att_num = len(all_att)

    # ---- dummy node: one edge in each direction per (entity, attribute) seen in training ----
    # The dummy node takes id ``entity_num``.
    # NOTE(review): this only yields a fresh id if entity ids are exactly
    # 0..entity_num-1 -- confirm against the data.
    dummy_node = entity_num
    dummy_edges = []
    for entity, attributes in entity_att_dict.items():
        for att in attributes:
            # NOTE(review): forward/inverse ids are att+relation_num and
            # att+relation_num+1, so consecutive attribute ids would share a
            # relation id; presumably attribute ids are spaced by two (the
            # set below reserves 2*att_num ids) -- verify before changing.
            forward = att + relation_num
            inverse = att + relation_num + 1

            dummy_edges.append([dummy_node, forward, entity])
            rhs_dict[(dummy_node, forward)].append(entity)
            lhs_dict[(entity, forward)].append(dummy_node)

            dummy_edges.append([entity, inverse, dummy_node])
            rhs_dict[(entity, inverse)].append(dummy_node)
            lhs_dict[(dummy_node, inverse)].append(entity)

    # Training set plus dummy edges; register the dummy node and relations.
    ent_train = ent_train + dummy_edges
    all_entity.add(entity_num)
    all_relation = all_relation | {relation_num + i for i in range(2 * att_num)}

    return (ent_train, ent_valid, ent_test, att_train, att_valid, att_test,
            rhs_dict, lhs_dict, all_entity, all_relation, all_att,
            each_att_entity_train, each_att_entity_test)


def main(test_graph_path=DEFAULT_TEST_GRAPH_PATH,
         split_path=DEFAULT_SPLIT_PATH,
         output_path=DEFAULT_OUTPUT_PATH):
    """Load the raw splits, preprocess them and pickle the result to ``output_path``."""
    # Per-unit min/max used to scale the numeric values.
    num_min_max = get_minmax_each_att(test_graph_path)

    # Raw triple splits.
    # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    print("Load Train Graph " + split_path)
    with open(split_path, "rb") as file:
        splits = pickle.load(file)

    preprocessed_data = build_preprocessed_data(splits, num_min_max)
    with open(output_path, "wb") as file:
        pickle.dump(preprocessed_data, file)
    print("work done!")


if __name__ == "__main__":
    main()